Setting a confidence threshold for ResNet-18 predictions in PyTorch

58

How can I apply a confidence threshold to the sigmoid outputs of a ResNet-18 tagger in PyTorch?

from PIL import Image
import torch
from torchvision import transforms
# Classify one image with a multi-label tagger and keep the tags whose
# sigmoid confidence exceeds a threshold.
#
# NOTE(review): `model` is not defined in this snippet; it is assumed to be a
# torch.nn.Module created earlier (the comment on `probs` below describes 6000
# Danbooru tag outputs) -- TODO confirm.

# Confidence cutoff applied to the per-tag sigmoid probabilities; tune per use case.
thresh = 0.2

# Convert to RGB: a PNG may be RGBA or paletted, while Normalize below expects
# exactly three channels. The `with` block closes the file handle once the
# converted (fully loaded) copy has been made.
with Image.open("img/danbooru_resnet1.png") as img:  # load an image of your choice
    input_image = img.convert("RGB")

preprocess = transforms.Compose([
    transforms.Resize(360),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.7137, 0.6628, 0.6519], std=[0.2970, 0.3017, 0.2979]),
])
input_tensor = preprocess(input_image)
input_batch = input_tensor.unsqueeze(0)  # create a mini-batch as expected by the model

# Run on the GPU when one is available, otherwise stay on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
input_batch = input_batch.to(device)
model.to(device)
# eval() puts layers such as BatchNorm/Dropout (if present) into inference
# mode; otherwise predictions would depend on batch statistics / dropout masks.
model.eval()

with torch.no_grad():
    output = model(input_batch)

# The output has unnormalized scores. To get probabilities, run a sigmoid on it.
probs = torch.sigmoid(output[0])  # Tensor of shape 6000, with confidence scores over Danbooru's top 6000 tags

# Apply the threshold.
tmp = probs[probs > thresh]  # scores strictly above the threshold, in tag-index order
inds = probs.argsort(descending=True)  # all tag indices, most confident first
# `inds` is sorted by descending probability, so its first len(tmp) entries
# are exactly the tags whose probability exceeds `thresh`.
top_inds = inds[: len(tmp)]
top_probs = probs[top_inds]

Comments

Submit
0 Comments